{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction to Handwritten Digit Classification\n",
    "\n",
    "*Copyright (c) 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.*\n",
    "\n",
    "Computer vision (CV) is the field concerned with enabling computers to extract meaningful information from images, videos, and other visual inputs. It is a fundamental and important area of artificial intelligence. Handwritten digit classification is one of its basic tasks: a model is trained and tested on the MNIST dataset \\[1\\] to verify that it has basic image-recognition ability.\n",
    "\n",
    "The MNIST dataset consists of handwritten digits like those shown in the figure below. It covers 10 classes, the digits 0 through 9, and each sample is a grayscale image of 28\\*28 pixels. The training set contains 60,000 images and the test set contains 10,000. Any model designed for image classification can therefore be evaluated on MNIST.\n",
    "\n",
    "![mnist-example](mnist_example.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST Classification Using the VSQL Model\n",
    "\n",
    "### Data Encoding\n",
    "\n",
    "In the handwritten digit classification problem, the input is an image of a handwritten digit and the output is the class of that image (i.e., one of the digits 0-9). Because quantum computers operate on quantum states, the image must first be encoded into a quantum state. We represent the image as a two-dimensional matrix, flatten this matrix into a 1D vector, and pad the vector with zeros until its length is an integer power of 2. The vector is then normalized to obtain a quantum state that a quantum computer can process. For example, a 28\\*28 image has 784 pixels, which are padded to $2^{10} = 1024$ entries, the dimension of a 10-qubit state.\n",
    "\n",
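    "The encoding steps above can be sketched in plain NumPy. This is only an illustration under stated assumptions: `encode_image` is a hypothetical helper, a random array stands in for an MNIST image, and the preprocessing actually performed inside `paddle_quantum` may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def encode_image(image):\n",
    "    # Hypothetical helper: flatten, zero-pad to a power-of-two length, normalize.\n",
    "    vec = np.asarray(image, dtype=np.float64).reshape(-1)\n",
    "    norm = np.linalg.norm(vec)\n",
    "    if norm == 0:\n",
    "        raise ValueError('An all-zero image cannot be normalized.')\n",
    "    # Smallest number of qubits whose state dimension fits the vector.\n",
    "    num_qubits = max(1, (vec.size - 1).bit_length())\n",
    "    padded = np.zeros(2 ** num_qubits)\n",
    "    padded[:vec.size] = vec\n",
    "    # Zero padding does not change the norm, so dividing by it gives unit length.\n",
    "    return padded / norm, num_qubits\n",
    "\n",
    "\n",
    "# Assumption: a random 28 x 28 array stands in for an MNIST image.\n",
    "rng = np.random.default_rng(0)\n",
    "state, num_qubits = encode_image(rng.random((28, 28)))\n",
    "print(state.shape, num_qubits, np.linalg.norm(state))\n",
    "```\n",
    "\n",
    "Because the result has unit norm, its entries can be read as the amplitudes of a `num_qubits`-qubit state.\n",
    "\n",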
    "\n",
    "### Introduction to the VSQL Model\n",
    "\n",
    "Variational shadow quantum learning (VSQL) is a hybrid quantum-classical algorithm for supervised learning. It combines parameterized quantum circuits (PQCs) with classical shadows. Unlike common variational quantum algorithms (VQAs), VSQL extracts only local features from subspaces rather than from the whole Hilbert space in which the quantum states live.\n",
    "\n",
    "The figure below shows a schematic of the VSQL model.\n",
    "\n",
    "![vsql-model](vsql_model.png)\n",
    "\n",
    "VSQL takes a quantum state as input. A local parameterized quantum circuit is applied to the input state iteratively and measured each time to obtain local shadow features. All of the shadow features are then passed to a classical neural network, which outputs the predicted label.\n",
    "\n",
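    "To make the idea of local shadow features concrete, here is a toy NumPy sketch. It assumes the local circuit moves across the register one qubit at a time and that each feature is the expectation value of a fixed observable on the current window. The `Ry` rotations, the Pauli-X observable, and the helper name `local_shadow_features` are illustrative choices, not the `paddle_quantum` implementation, and the classical neural network that consumes the features is omitted.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def local_shadow_features(state, num_qubits, num_shadow, unitary, observable):\n",
    "    # Hypothetical helper: one expectation value per window position.\n",
    "    dim = 2 ** num_shadow\n",
    "    features = []\n",
    "    for start in range(num_qubits - num_shadow + 1):\n",
    "        # Split the amplitudes into (qubits before, window, qubits after).\n",
    "        psi = state.reshape(2 ** start, dim, -1)\n",
    "        # Apply the local circuit to the window qubits only.\n",
    "        psi = np.einsum('ab,ibj->iaj', unitary, psi)\n",
    "        # Expectation value of the observable acting on the window.\n",
    "        value = np.einsum('iaj,ab,ibj->', psi.conj(), observable, psi)\n",
    "        features.append(value.real)\n",
    "    return np.array(features)\n",
    "\n",
    "\n",
    "# Assumed local circuit: the same Ry(theta) rotation on both window qubits.\n",
    "theta = np.pi / 6\n",
    "ry = np.array([[np.cos(theta / 2), -np.sin(theta / 2)],\n",
    "               [np.sin(theta / 2), np.cos(theta / 2)]])\n",
    "pauli_x = np.array([[0.0, 1.0], [1.0, 0.0]])\n",
    "\n",
    "num_qubits, num_shadow = 4, 2\n",
    "state = np.zeros(2 ** num_qubits)\n",
    "state[0] = 1.0  # all qubits in the zero state\n",
    "features = local_shadow_features(\n",
    "    state, num_qubits, num_shadow, np.kron(ry, ry), np.kron(pauli_x, pauli_x))\n",
    "print(features)\n",
    "```\n",
    "\n",
    "With four qubits and a two-qubit window there are three window positions, so the sketch returns three features. For this particular input each one equals `sin(theta) ** 2`.\n",
    "\n",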
    "### Workflow\n",
    "\n",
    "With these pieces in place, we only need to train the VSQL model on the MNIST dataset until it converges. The trained model can then be used to classify handwritten digits. The training process is shown below.\n",
    "\n",
    "![vsql-pipeline](vsql_pipeline_en.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Use\n",
    "\n",
    "### Predict Using the Model\n",
    "\n",
    "We provide a trained model that can be used directly to classify images of the digits 0 and 1. Set the relevant options in the `example.toml` configuration file, then run `python vsql_classification.py --config example.toml` to classify the input images with the trained VSQL model.\n",
    "\n",
    "### Online Demo\n",
    "\n",
    "The cells below provide a demo that can be run directly in this notebook. First, define the contents of the configuration file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_toml = r\"\"\"\n",
    "# The overall configuration file of the model.\n",
    "# Enter the current task, which can be 'train' or 'test', representing training and prediction respectively. Here we use test, indicating that we want to make a prediction.\n",
    "task = 'test'\n",
    "# The file path of the image to be predicted.\n",
    "image_path = 'data_0.png'\n",
    "# Whether the image path above is a folder or not. For folder paths, we will predict all image files inside the folder. This way you can test multiple images at once.\n",
    "is_dir = false\n",
    "# The file path of the trained model parameter file.\n",
    "model_path = 'vsql.pdparams'\n",
    "# The number of qubits that the quantum circuit contains.\n",
    "num_qubits = 10\n",
    "# The number of qubits that the shadow circuit contains.\n",
    "num_shadow = 2\n",
    "# Circuit depth.\n",
    "depth = 1\n",
    "# The classes to be predicted by the model. Here, the model distinguishes between 0 and 1.\n",
    "classes = [0, 1]\n",
    "\"\"\"\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, run the prediction code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For the input image, the model has 89.22% confidence that it is 0, and 10.78% confidence that it is 1.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
    "\n",
    "import toml\n",
    "from paddle_quantum.qml.vsql import train, inference\n",
    "\n",
    "# Parse the TOML string and dispatch on the requested task.\n",
    "config = toml.loads(test_toml)\n",
    "task = config.pop('task')\n",
    "if task == 'train':\n",
    "    train(**config)\n",
    "elif task == 'test':\n",
    "    prediction, prob = inference(**config)\n",
    "    if config['is_dir']:\n",
    "        # Folder input: report the predicted label of every image.\n",
    "        print(f\"The predicted labels of the input images are {str(prediction)[1:-1]}, respectively.\")\n",
    "    else:\n",
    "        # Single image: report the confidence for each candidate class.\n",
    "        prob = prob[0]\n",
    "        msg = 'For the input image, the model has'\n",
    "        for idx, item in enumerate(prob):\n",
    "            if idx == len(prob) - 1:\n",
    "                msg += 'and'\n",
    "            label = config['classes'][idx]\n",
    "            msg += f' {item:3.2%} confidence that it is {label:d}'\n",
    "            msg += '.' if idx == len(prob) - 1 else ', '\n",
    "        print(msg)\n",
    "else:\n",
    "    raise ValueError(f\"Unknown task {task!r}; it must be 'train' or 'test'.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To test other images, change the image path in the configuration file and rerun the code above."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Note\n",
    "\n",
    "The provided model is a binary classifier that can only distinguish the handwritten digits 0 and 1. Any other classification task requires retraining the model.\n",
    "\n",
    "### Dataset Structure\n",
    "\n",
    "To train on a custom dataset, prepare it in the following format. Place `train.txt` and `test.txt` in the dataset folder, plus `dev.txt` if a validation set is needed. Each line of these files describes one sample and contains the file path of the image and its label, separated by a tab.\n",
    "\n",
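    "As a rough illustration of this layout, the sketch below writes a two-line `train.txt` and parses it back. The image paths, the integer labels, and the helper `read_split` are hypothetical; whether paths are resolved relative to the dataset folder is not specified here, so check the training code before relying on it.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def read_split(list_file):\n",
    "    # Hypothetical helper: parse one split file into (image_path, label) pairs.\n",
    "    samples = []\n",
    "    with open(list_file, encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if not line:\n",
    "                continue  # skip blank lines\n",
    "            # The label is the field after the last tab on the line.\n",
    "            path, label = line.rsplit('\\t', 1)\n",
    "            samples.append((path, int(label)))\n",
    "    return samples\n",
    "\n",
    "\n",
    "# Write a tiny example file with made-up paths, then read it back.\n",
    "with tempfile.TemporaryDirectory() as root:\n",
    "    list_file = os.path.join(root, 'train.txt')\n",
    "    with open(list_file, 'w', encoding='utf-8') as f:\n",
    "        f.write('images/zero_001.png\\t0\\n')\n",
    "        f.write('images/one_001.png\\t1\\n')\n",
    "    samples = read_split(list_file)\n",
    "\n",
    "print(samples)\n",
    "```\n",
    "\n",
    "The same format applies to `test.txt` and, if present, `dev.txt`.\n",
    "\n",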
    "### Introduction to the Configuration File\n",
    "\n",
    "`test.toml` is a complete reference configuration for testing, and `train.toml` is a complete reference configuration for training. You can use these files to train and test the model quickly."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Citation\n",
    "\n",
    "```tex\n",
    "@inproceedings{li2021vsql,\n",
    "  title={VSQL: Variational shadow quantum learning for classification},\n",
    "  author={Li, Guangxi and Song, Zhixin and Wang, Xin},\n",
    "  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},\n",
    "  volume={35},\n",
    "  number={9},\n",
    "  pages={8357--8365},\n",
    "  year={2021}\n",
    "}\n",
    "```\n",
    "\n",
    "## Reference\n",
    "\n",
    "\\[1\\] Yann LeCun (Courant Institute, NYU), Corinna Cortes (Google Labs, New York), and Christopher J.C. Burges (Microsoft Research, Redmond). \"The MNIST Database of Handwritten Digits\"."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15 (default, Nov 10 2022, 12:46:26) \n[Clang 14.0.6 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
